/*
* Copyright 1999-2002 Carnegie Mellon University.
* Portions Copyright 2002 Sun Microsystems, Inc.
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* All Rights Reserved.  Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/

package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.*;
import edu.cmu.sphinx.util.LogMath;

import java.io.IOException;
import java.util.HashMap;
import java.util.logging.Logger;

/** Manages the HMM pools. */
class HMMPoolManager {

    private HMMManager hmmManager;
    private HashMap<Object, Integer> indexMap;
    private Pool<float[]> meansPool;
    private Pool<float[]> variancePool;
    private Pool<float[][]> matrixPool;
    private GaussianWeights mixtureWeights;
    
    private Pool<Buffer> meansBufferPool;
    private Pool<Buffer> varianceBufferPool;
    private Pool<Buffer[]> matrixBufferPool;
    private Pool<Buffer> mixtureWeightsBufferPool;

    private Pool<Senone> senonePool;
    private LogMath logMath;

    private float logMixtureWeightFloor;
    private float logTransitionProbabilityFloor;
    private float varianceFloor;
    private float logLikelihood;
    private float currentLogLikelihood;

    /** The logger for this class */
    private static Logger logger = Logger.getLogger("edu.cmu.sphinx.linguist.acoustic.HMMPoolManager");

    /**
     * Constructor for this pool manager. It gets the pointers to the pools from a loader.
     *
     * @param loader the loader
     * @throws IOException 
     */
    protected HMMPoolManager(Loader loader) throws IOException {
    	loader.load();
        hmmManager = loader.getHMMManager();
        indexMap = new HashMap<Object, Integer>();
        meansPool = loader.getMeansPool();
        variancePool = loader.getVariancePool();
        mixtureWeights = loader.getMixtureWeights();
        matrixPool = loader.getTransitionMatrixPool();
        senonePool = loader.getSenonePool();

//	logMath = LogMath.getLogMath();
//        float mixtureWeightFloor =
//	    props.getFloat(TiedStateAcousticModel.PROP_MW_FLOOR);
//	logMixtureWeightFloor = logMath.linearToLog(mixtureWeightFloor);
//        float transitionProbabilityFloor =
//	    props.getFloat(TiedStateAcousticModel.PROP_TP_FLOOR);
//	logTransitionProbabilityFloor =
//	    logMath.linearToLog(transitionProbabilityFloor);
//        varianceFloor =
//	    props.getFloat(TiedStateAcousticModel.PROP_VARIANCE_FLOOR);

        createBuffers();
        logLikelihood = 0.0f;
        logMath = LogMath.getLogMath();
    }

    /** Recreates the buffers. */
    protected void resetBuffers() {
        createBuffers();
        logLikelihood = 0.0f;
    }

    /** Create buffers for all pools used by the trainer in this pool manager. */
    protected void createBuffers() {
        // the option false or true refers to whether the buffer is in
        // log scale or not, true if it is.
        meansBufferPool = create1DPoolBuffer(meansPool, false);
        varianceBufferPool = create1DPoolBuffer(variancePool, false);
        matrixBufferPool = create2DPoolBuffer(matrixPool, true);
        mixtureWeightsBufferPool = createWeightsPoolBuffer(mixtureWeights);
    }


    /** Create buffers for a given pool. */
    private Pool<Buffer> create1DPoolBuffer(Pool<float[]> pool, boolean isLog) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(pool.getName());

        for (int i = 0; i < pool.size(); i++) {
            float[] element = pool.get(i);
            indexMap.put(element, i);
            Buffer buffer = new Buffer(element.length, isLog, i);
            bufferPool.put(i, buffer);
        }
        return bufferPool;
    }
    
    private Pool<Buffer> createWeightsPoolBuffer(GaussianWeights mixtureWeights) {
         Pool<Buffer> bufferPool = new Pool<Buffer>(mixtureWeights.getName());
         int statesNum = mixtureWeights.getStatesNum();
         int streamsNum = mixtureWeights.getStreamsNum();
         int gauPerState = mixtureWeights.getGauPerState();
         for (int i = 0; i < streamsNum; i++) {
             for (int j = 0; j < statesNum; j++) {
                 int id = i * statesNum + j;
                 Buffer buffer = new Buffer(gauPerState, true, id);
                 bufferPool.put(id, buffer);
             }
         }
         return bufferPool;
    }
    
    /** Create buffers for a given pool. */
    private Pool<Buffer[]> create2DPoolBuffer(Pool<float[][]> pool, boolean isLog) {
        Pool<Buffer[]> bufferPool = new Pool<Buffer[]>(pool.getName());

        for (int i = 0; i < pool.size(); i++) {
            float[][] element = pool.get(i);
            indexMap.put(element, i);
            int poolSize = element.length;
            Buffer[] bufferArray = new Buffer[poolSize];
            for (int j = 0; j < poolSize; j++) {
                bufferArray[j] = new Buffer(element[j].length, isLog, j);
            }
            bufferPool.put(i, bufferArray);
        }
        return bufferPool;
    }

    /**
     * Accumulate the TrainerScore into the buffers.
     *
     * @param index the current index into the TrainerScore vector
     * @param score the TrainerScore
     */
    protected void accumulate(int index, TrainerScore[] score) {
        accumulate(index, score, null);
    }

    /**
     * Accumulate the TrainerScore into the buffers.
     *
     * @param index     the current index into the TrainerScore vector
     * @param score     the TrainerScore for the current frame
     * @param nextScore the TrainerScore for the next time frame
     */
    protected void accumulate(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        int senoneID;
        TrainerScore thisScore = score[index];

        // We should be doing this just once per utterance...
        // currentLogLikelihood = thisScore.getLogLikelihood();

        // Since we're scaling, the loglikelihood disappears...
        currentLogLikelihood = 0;
        // And the total becomes the sum of (-) scaling factors
        logLikelihood -= score[0].getScalingFactor();

        SenoneHMMState state = (SenoneHMMState) thisScore.getState();
        if (state == null) {
            // We only care about the case "all models"
            senoneID = thisScore.getSenoneID();
            if (senoneID == TrainerAcousticModel.ALL_MODELS) {
                accumulateMean(senoneID, score[index]);
                accumulateVariance(senoneID, score[index]);
                accumulateMixture(senoneID, score[index]);
                accumulateTransition(senoneID, index, score, nextScore);
            }
        } else {
            // If state is non-emitting, we presume there's only one
            // transition out of it. Therefore, we only accumulate
            // data for emitting states.
            if (state.isEmitting()) {
                senoneID = senonePool.indexOf(state.getSenone());
                // accumulateMean(senoneID, score[index]);
                // accumulateVariance(senoneID, score[index]);
                accumulateMixture(senoneID, score[index]);
                accumulateTransition(senoneID, index, score, nextScore);
            }
        }
    }

    /** Accumulate the means. */
    private void accumulateMean(int senone, TrainerScore score) {
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMean(i, score);
            }
        } else {
            GaussianMixture gaussian = (GaussianMixture)senonePool.get(senone);
            MixtureComponent[] mix = gaussian.getMixtureComponents();
            for (int i = 0; i < mix.length; i++) {
                float[] mean = mix[i].getMean();
                // int indexMean = meansPool.indexOf(mean);
                int indexMean = indexMap.get(mean);
                assert indexMean >= 0;
                assert indexMean == senone;
                Buffer buffer = meansBufferPool.get(indexMean);
                float[] feature = ((FloatData) score.getData()).getValues();
                double[] data = new double[feature.length];
                float prob = score.getComponentGamma()[i];
                prob -= currentLogLikelihood;
                double dprob = logMath.logToLinear(prob);
                // prob = (float) logMath.logToLinear(prob);
                for (int j = 0; j < data.length; j++) {
                    data[j] = feature[j] * dprob;
                }
                buffer.accumulate(data, dprob);
            }
        }
    }


    /** Accumulate the variance. */
    private void accumulateVariance(int senone, TrainerScore score) {
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateVariance(i, score);
            }
        } else {
            GaussianMixture gaussian = (GaussianMixture)senonePool.get(senone);
            MixtureComponent[] mix = gaussian.getMixtureComponents();
            for (int i = 0; i < mix.length; i++) {
                float[] mean = mix[i].getMean();
                float[] variance = mix[i].getVariance();
                // int indexVariance = variancePool.indexOf(variance);
                int indexVariance = indexMap.get(variance);
                Buffer buffer = varianceBufferPool.get(indexVariance);
                float[] feature = ((FloatData) score.getData()).getValues();
                double[] data = new double[feature.length];
                float prob = score.getComponentGamma()[i];
                prob -= currentLogLikelihood;
                double dprob = logMath.logToLinear(prob);
                for (int j = 0; j < data.length; j++) {
                    data[j] = (feature[j] - mean[j]);
                    data[j] *= data[j] * dprob;
                }
                buffer.accumulate(data, dprob);
            }
        }
    }

    /** Accumulate the mixture weights. */
    private void accumulateMixture(int senone, TrainerScore score) {
        // The index into the senone pool and the mixture weight pool
        // is the same
        if (senone == TrainerAcousticModel.ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMixture(i, score);
            }
        } else {
            Buffer buffer = mixtureWeightsBufferPool.get(senone);
            for (int i = 0; i < mixtureWeights.getGauPerState(); i++) {
                float prob = score.getComponentGamma()[i];
                prob -= currentLogLikelihood;
                buffer.logAccumulate(prob, i, logMath);
            }
        }
    }

    /**
     * Accumulate transitions from a given state.
     *
     * @param indexScore the current index into the TrainerScore
     * @param score      the score information
     * @param nextScore  the score information for the next frame
     */
    private void accumulateStateTransition(int indexScore, TrainerScore[] score, TrainerScore[] nextScore) {
        HMMState state = score[indexScore].getState();
        if (state == null) {
            // Non-emitting state
            return;
        }
        int indexState = state.getState();
        SenoneHMM hmm = (SenoneHMM) state.getHMM();
        float[][] matrix = hmm.getTransitionMatrix();

        // Find the index for current matrix in the transition matrix pool
        // int indexMatrix = matrixPool.indexOf(matrix);
        int indexMatrix = indexMap.get(matrix);

        // Find the corresponding buffer
        Buffer[] bufferArray = matrixBufferPool.get(indexMatrix);

        // Let's concentrate on the transitions *from* the current state
        float[] vector = matrix[indexState];

        for (int i = 0; i < vector.length; i++) {
            // Make sure this is a valid transition
            if (vector[i] != LogMath.LOG_ZERO) {

                // We're assuming that if the states have position "a"
                // and "b" in the HMM, they'll have positions "k+a"
                // and "k+b" in the graph, that is, their relative
                // position is the same.

                // Distance between current state and "to" state in
                // the HMM
                int dist = i - indexState;

                // "to" state in the graph
                int indexNextScore = indexScore + dist;

                // Make sure the next state is non-emitting (the last
                // in the HMM), or in the same HMM.
                assert ((nextScore[indexNextScore].getState() == null) ||
                        (nextScore[indexNextScore].getState().getHMM() == hmm));
                float alpha = score[indexScore].getAlpha();
                float beta = nextScore[indexNextScore].getBeta();
                float transitionProb = vector[i];
                float outputProb = nextScore[indexNextScore].getScore();
                float prob = alpha + beta + transitionProb + outputProb;
                prob -= currentLogLikelihood;
                // i is the index into the next state.
                bufferArray[indexState].logAccumulate(prob, i, logMath);
                /*
        if ((indexMatrix == 0) && (i == 2)) {
            //    	    System.out.println("Out: " + outputProb);
                //	    	    bufferArray[indexState].dump();
        }
            */
            }
        }
    }

    /**
     * Accumulate transitions from a given state.
     *
     * @param indexState the state index
     * @param hmm        the HMM
     * @param value      the value to accumulate
     */
    private void accumulateStateTransition(int indexState, SenoneHMM hmm, float value) {
        // Find the transition matrix in this hmm
        float[][] matrix = hmm.getTransitionMatrix();

        // Find the vector with transitions from the current state to
        // other states.
        float[] stateVector = matrix[indexState];

        // Find the index of the current transition matrix in the
        // transition matrix pool.
        // int indexMatrix = matrixPool.indexOf(matrix);
        int indexMatrix = indexMap.get(matrix);

        // Find the buffer for the transition matrix.
        Buffer[] bufferArray = matrixBufferPool.get(indexMatrix);

        // Accumulate for the transitions from current state
        for (int i = 0; i < stateVector.length; i++) {
            // Make sure we're not trying to accumulate in an invalid
            // transition.
            if (stateVector[i] != LogMath.LOG_ZERO) {
                bufferArray[indexState].logAccumulate(value, i, logMath);
            }
        }
    }

    /** Accumulate the transition probabilities. */
    private void accumulateTransition(int indexHmm, int indexScore, TrainerScore[] score, TrainerScore[] nextScore) {
        if (indexHmm == TrainerAcousticModel.ALL_MODELS) {
            // Well, special case... we want to add an amount to all
            // the states in all models
            for (HMM hmm : hmmManager) {
                for (int j = 0; j < hmm.getOrder(); j++) {
                    accumulateStateTransition(j, (SenoneHMM)hmm, score[indexScore].getScore());
                }
            }
        } else {
            // For transition accumulation, we don't consider the last
            // time frame, since there's no transition from there to
            // anywhere...
            if (nextScore != null) {
                accumulateStateTransition(indexScore, score, nextScore);
            }
        }
    }

    /** Update the log likelihood. This method should be called for every utterance. */
    protected void updateLogLikelihood() {
        // logLikelihood += currentLogLikelihood;
    }

    /**
     * Normalize the buffers.
     *
     * @return the log likelihood associated with the current training set
     */
    protected float normalize() {
        normalizePool(meansBufferPool);
        normalizePool(varianceBufferPool);
        logNormalizePool(mixtureWeightsBufferPool);
        logNormalize2DPool(matrixBufferPool, matrixPool);
        return logLikelihood;
    }

    /**
     * Normalize a single buffer pool.
     *
     * @param pool the buffer pool to normalize
     */
    private void normalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.normalize();
            }
        }
    }

    /**
     * Normalize a single buffer pool in log scale.
     *
     * @param pool the buffer pool to normalize
     */
    private void logNormalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.logNormalize();
            }
        }
    }

    /**
     * Normalize a 2D buffer pool in log scale. Typically, this is the case with the transition matrix, which also needs
     * a mask for values that are allowed, and therefor have to be updated, or not allowed, and should be ignored.
     *
     * @param pool     the buffer pool to normalize
     * @param maskPool pool containing a mask with zero/non-zero values.
     */
    private void logNormalize2DPool(Pool<Buffer[]> pool, Pool<float[][]> maskPool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer[] bufferArray = pool.get(i);
            float[][] mask = maskPool.get(i);
            for (int j = 0; j < bufferArray.length; j++) {
                if (bufferArray[j].wasUsed()) {
                    bufferArray[j].logNormalizeNonZero(mask[j]);
                }
            }
        }
    }

    /** Update the models. */
    protected void update() {
        updateMeans();
        updateVariances();
        recomputeMixtureComponents();
        updateMixtureWeights();
        updateTransitionMatrices();
    }

    /**
     * Copy one vector onto another.
     *
     * @param in  the source vector
     * @param out the destination vector
     */
    private void copyVector(float[] in, float[] out) {
        assert in.length == out.length;
        System.arraycopy(in, 0, out, 0, in.length);
    }

    /** Update the means. */
    private void updateMeans() {
        assert meansPool.size() == meansBufferPool.size();
        for (int i = 0; i < meansPool.size(); i++) {
            float[] means = meansPool.get(i);
            Buffer buffer = meansBufferPool.get(i);
            if (buffer.wasUsed()) {
                float[] meansBuffer = buffer.getValues();
                copyVector(meansBuffer, means);
            } else {
                logger.info("Senone " + i + " not used.");
            }
        }
    }

    /** Update the variances. */
    private void updateVariances() {
        assert variancePool.size() == varianceBufferPool.size();
        for (int i = 0; i < variancePool.size(); i++) {
            float[] means = meansPool.get(i);
            float[] variance = variancePool.get(i);
            Buffer buffer = varianceBufferPool.get(i);
            if (buffer.wasUsed()) {
                float[] varianceBuffer = buffer.getValues();
                assert means.length == varianceBuffer.length;
                for (int j = 0; j < means.length; j++) {
                    varianceBuffer[j] -= means[j] * means[j];
                    if (varianceBuffer[j] < varianceFloor) {
                        varianceBuffer[j] = varianceFloor;
                    }
                }
                copyVector(varianceBuffer, variance);
            }
        }
    }

    /** Recompute the precomputed values in all mixture components. */
    private void recomputeMixtureComponents() {
        for (int i = 0; i < senonePool.size(); i++) {
            GaussianMixture gMix = (GaussianMixture) senonePool.get(i);
            MixtureComponent[] mixComponent = gMix.getMixtureComponents();
            for (MixtureComponent component : mixComponent) {
                component.precomputeDistance();
            }
        }
    }

    /** Update the mixture weights. */
    private void updateMixtureWeights() {
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        assert statesNum * streamsNum == mixtureWeightsBufferPool.size();
        for (int i = 0; i < streamsNum; i++) {
            for (int j = 0; j < statesNum; j++) {
                int id = i * statesNum + j;
                Buffer buffer = mixtureWeightsBufferPool.get(id);
                if (buffer.wasUsed()) {
                    if (buffer.logFloor(logMixtureWeightFloor)) {
                        buffer.logNormalizeToSum(logMath);
                    }
                    float[] mixtureWeightsBuffer = buffer.getValues();
                    mixtureWeights.put(j, i, mixtureWeightsBuffer);
                }
            }
        }
    }

    /** Update the transition matrices. */
    private void updateTransitionMatrices() {
        assert matrixPool.size() == matrixBufferPool.size();
        for (int i = 0; i < matrixPool.size(); i++) {
            float[][] matrix = matrixPool.get(i);
            Buffer[] bufferArray = matrixBufferPool.get(i);
            for (int j = 0; j < matrix.length; j++) {
                Buffer buffer = bufferArray[j];
                if (buffer.wasUsed()) {
                    for (int k = 0; k < matrix[j].length; k++) {
                        float bufferValue = buffer.getValue(k);
                        if (bufferValue != LogMath.LOG_ZERO) {
                            assert matrix[j][k] != LogMath.LOG_ZERO;
                            if (bufferValue < logTransitionProbabilityFloor) {
                                buffer.setValue(k, logTransitionProbabilityFloor);
                            }
                        }
                    }
                    buffer.logNormalizeToSum(logMath);
                    copyVector(buffer.getValues(), matrix[j]);
                }
            }
        }
    }
}
